import torch  # explicit import; previously `torch` was only in scope via `from config import *`
from d2l import torch as d2l
from torch import nn

from config import *
from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder
from attention import sequence_mask
from data import loaddata, MyDataset, Lang
import torch.nn.functional as F


def loss_fn(out, tar, ignore_index=2):
    """Token-level cross-entropy over flattened logits.

    Args:
        out: logits of shape (..., vocab_size), e.g. (batch, seq_len, vocab).
        tar: target token indices with shape matching out's leading dims.
        ignore_index: token id excluded from the loss. Defaults to 2, the
            padding id used by this project.  NOTE(review): 2 is also the
            hard-coded <bos> id in train_seq2seq — confirm the pad and <bos>
            ids really coincide in the vocabulary.

    Returns:
        Scalar tensor: mean cross-entropy over non-ignored positions.
    """
    out = out.view(-1, out.shape[-1])
    tar = tar.view(-1)
    return F.cross_entropy(out, tar, ignore_index=ignore_index)


def _teacher_forcing_loss(net, batch, device):
    """Run one teacher-forced forward pass and return the cross-entropy loss.

    `batch` is (X, X_valid_len, Y, Y_valid_len).  Y carries <eos> but no
    <bos>, so the decoder input is <bos> followed by Y[:, :-1] (targets
    shifted right by one token); Y_hat and Y then align position-wise.
    """
    X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
    # NOTE(review): 2 is assumed to be lang.word2index['<bos>'] (the original
    # hard-coded it with that comment); it is also the id loss_fn ignores as
    # padding — confirm both against the vocabulary.
    bos = torch.full((Y.shape[0], 1), 2, dtype=torch.long, device=device)
    dec_input = torch.cat([bos, Y[:, :-1]], 1)
    # Teacher forcing: feed the ground-truth prefix, predict the next token.
    Y_hat, _ = net(X, dec_input, X_valid_len)
    return loss_fn(Y_hat, Y)


def train_seq2seq(net, data_iter, lr, num_epochs, lang, device, first_train=True, min_ppl=1e9):
    """Train a sequence-to-sequence model with teacher forcing.

    After each epoch the model is re-evaluated on `data_iter` and, whenever
    the average perplexity improves on `min_ppl`, the encoder and decoder
    submodules are checkpointed under MODEL_ROOT.

    Args:
        net: encoder-decoder model; called as net(X, dec_input, X_valid_len).
        data_iter: iterable of (X, X_valid_len, Y, Y_valid_len) batches.
        lr: Adam learning rate.
        num_epochs: number of passes over data_iter.
        lang: vocabulary object (currently unused; the <bos> id is hard-coded).
        device: torch device to train on.
        first_train: if True, Xavier-initialize Linear/GRU weights first.
        min_ppl: checkpoint only when average perplexity drops below this.

    Returns:
        The best (lowest) average perplexity observed.
    """

    def xavier_init_weights(m):
        # Xavier-initialize Linear and GRU weight matrices (biases untouched).
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    if first_train:
        net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    for epoch in range(num_epochs):
        # --- training pass ---
        net.train()
        total_loss = 0.0
        for batch in data_iter:
            optimizer.zero_grad()
            l = _teacher_forcing_loss(net, batch, device)
            l.backward()  # loss is already a scalar
            optimizer.step()
            total_loss += l.item()
        print(total_loss)

        # --- evaluation pass over the same data ---
        net.eval()
        PPL_AVG = 0.0
        with torch.no_grad():
            total_loss = 0.0
            for batch in data_iter:
                l = _teacher_forcing_loss(net, batch, device)
                total_loss += l.item()
                PPL_AVG += torch.exp(l).item()  # per-batch perplexity
            PPL_AVG /= len(data_iter)
            print(total_loss, PPL_AVG)

        if PPL_AVG < min_ppl:
            # Checkpoint the best model so far.  The original saved the
            # module-level globals `encoder`/`decoder`, which raises NameError
            # unless this module runs as the __main__ script; save the
            # submodules of `net` instead.  NOTE(review): assumes
            # EncoderDecoder exposes .encoder/.decoder — confirm in
            # transformer.py.  Filenames here lack the num_examples suffix
            # used by the __main__ block — confirm that is intentional.
            torch.save(net.encoder, MODEL_ROOT + "trans_encoder.mdl")
            torch.save(net.decoder, MODEL_ROOT + "trans_decoder.mdl")
            min_ppl = PPL_AVG

    return min_ppl




if __name__ == "__main__":
    # Load the preprocessed dataset and vocabulary.  DATA_ROOT and
    # num_examples come from `from config import *`.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # files produced by this project.
    dataset = torch.load(DATA_ROOT + "dataset{}".format(num_examples))
    lang = torch.load(DATA_ROOT + "lang{}".format(num_examples))

    train_iter = loaddata(dataset, 64)  # batch size 64
    print(len(lang))  # vocabulary size

    # Constructor arguments: vocab size, key/query/value sizes, hidden size, ...
    # (kept commented for reference — shows how the checkpoints loaded below
    # were originally created)
    # encoder = TransformerEncoder(
    #     len(lang), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    # decoder = TransformerDecoder(
    #     len(lang), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    # Resume from existing checkpoints (saving/loading whole modules, not
    # state_dicts).  The names `encoder`/`decoder` bound here are also
    # referenced as globals by train_seq2seq's checkpointing — do not rename.
    encoder = torch.load(MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    decoder = torch.load(MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))

    net = EncoderDecoder(encoder, decoder)
    # first_train=False: weights come from the checkpoints, skip re-init.
    train_seq2seq(net, train_iter, lr, num_epochs, lang, device, first_train=False)

    # Final (unconditional) save, overwriting the loaded checkpoints.
    torch.save(encoder, MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    torch.save(decoder, MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))